import json

from evaluate_qa import answer_match, clean_answer


def get_sqa_question_type(question):
    """Classify a question by the word it starts with.

    Leading whitespace is ignored and the comparison is case-insensitive.
    Only the leading characters are compared, so the match is on a prefix,
    not on a whole word (e.g. a question starting with "Island" is typed
    as 'is').

    Args:
        question: The question text.

    Returns:
        One of 'what', 'is', 'how', 'can', 'which', or 'others' when none
        of those prefixes match (including for an empty string).
    """
    stripped = question.lstrip()
    # Prefixes are tried in this fixed order; the first match wins.
    for prefix in ('what', 'is', 'how', 'can', 'which'):
        if stripped[:len(prefix)].lower() == prefix:
            return prefix
    return 'others'     # others


# Score one results file: per-question-type and overall exact-match accuracy.
#
# Only a single results file is evaluated. To score a different run, point
# the path below at another predictions JSON
# (e.g. 'outputs/sqa3d/checkpoint-1500/preds.json').
with open(
        './Projects/embodied-generalist/logs/LEO-1_tuning_scannet/eval_results/sqa3d/results.json',
        'r', encoding='utf-8') as results_file:
    data = json.load(results_file)

# Exact-match hit counts and sample counts, keyed by question type.
# 'avg' accumulates over all samples regardless of type.
em = {'what': 0, 'is': 0, 'how': 0, 'can': 0, 'which': 0, 'others': 0, 'avg': 0}
total = {'what': 0, 'is': 0, 'how': 0, 'can': 0, 'which': 0, 'others': 0, 'avg': 0}

for pred in data:
    # The first character of the instruction is dropped before typing the
    # question. NOTE(review): presumably a leading separator/marker in the
    # stored instruction -- confirm against the results format.
    question = pred['instruction'][1:]

    qa_type = get_sqa_question_type(question)
    print(question, qa_type)
    total[qa_type] += 1

    # Normalise the prediction and every reference answer the same way
    # before matching.
    response_pred = clean_answer(pred['response_pred'])
    ref_captions = [clean_answer(s) for s in pred['response_gt']]

    # answer_match returns two flags; only the first (exact match) is
    # aggregated here, the second is intentionally ignored.
    em_flag, _ = answer_match(response_pred, ref_captions)
    em[qa_type] += em_flag
    em['avg'] += em_flag
    total['avg'] += 1

for k, v in em.items():
    # Report 0 for question types that have no samples instead of
    # dividing by zero.
    accuracy = v / total[k] if total[k] > 0 else 0
    print(f"{k}: {accuracy:.4f} ({v} / {total[k]})")
